Predict the future? Let’s just see if we can predict what happened last year!
As you know, Facebook’s prophet model is easy to use and works well with clear seasonality
But it doesn’t have any built-in way to deal with confounders
This is a major limitation if you want a short-term, potentially high-variance forecast
We also know that interactions exist in most data sets. Prophet is good at fitting curves, but by ignoring covariates it walks right past $100 bills lying on the sidewalk
Let’s simulate some data
Let’s assume we have pretty predictable seasonal trends in subscribers.
Expand for code
# Packages ----
library(tidyverse)
library(lubridate)
library(usaidplot)
library(ggridges)
library(prophet)
library(bsts)
library(gt)
library(gtExtras)

# Seed R's RNG so the simulated data below are reproducible
set.seed(6292025)

# Weekly calendar ----
# 156 consecutive weekly dates, the first one being 2022-01-02
n_weeks <- 156
weeks <- seq.Date(
  from = as.Date("2022-01-02"),
  by = "week",
  length.out = n_weeks
)

# Category levels ----
sports <- c("baseball", "basketball", "football", "soccer")
sport_category <- c("school", "club", "recreation")
age_group <- c("U13", "13_18", "18plus")

# Expected weekly volume: a baseline of 500 plus two bell-shaped bumps,
# a large one centred on ISO week 18 and a smaller one centred on ISO week 42.
seasonal_trend <- function(week_date) {
  iso_wk <- isoweek(week_date)
  spring_peak <- exp(-0.03 * (iso_wk - 18)^2) * 800
  fall_bump <- exp(-0.02 * (iso_wk - 42)^2) * 200
  500 + spring_peak + fall_bump
}

# Simulation ----
# For every week: draw a noisy pool size, sample each person's attributes,
# then keep a person with a probability that depends on sport and age group
# (highest for baseball + U13).
sim_data <- weeks |>
  map(\(wk) {
    pool_size <- round(seasonal_trend(wk) + rnorm(1, 0, 50))
    if (pool_size <= 0) {
      return(tibble())
    }

    pool <- tibble(
      sport = sample(
        sports, pool_size,
        replace = TRUE, prob = c(0.5, 0.2, 0.15, 0.15)
      ),
      category = sample(sport_category, pool_size, replace = TRUE),
      age_group = sample(
        age_group, pool_size,
        replace = TRUE, prob = c(0.5, 0.3, 0.2)
      ),
      tenure_weeks = rpois(pool_size, lambda = 12)
    )

    pool |>
      mutate(
        week = wk,
        subscribe_prob = case_when(
          sport == "baseball" & age_group == "U13" ~ 0.95,
          sport == "baseball" ~ 0.7,
          age_group == "U13" ~ 0.6,
          TRUE ~ 0.4
        ),
        is_subscribed = rbinom(n(), 1, subscribe_prob)
      ) |>
      filter(is_subscribed == 1) |>
      select(week, sport, category, age_group, tenure_weeks)
  }) |>
  bind_rows()

# Weekly counts, one column per sport x age-group combination ----
summary_data <- sim_data |>
  count(week, sport, age_group) |>
  pivot_wider(
    names_from = c(sport, age_group),
    values_from = n,
    values_fill = 0
  )

# Every column except `week` is a group count; their row sum is the total
total_cols <- setdiff(colnames(summary_data), "week")
summary_data <- summary_data |>
  mutate(total_subs = rowSums(across(all_of(total_cols))))

# Plot the weekly total ----
ggplot(summary_data, aes(x = week, y = total_subs)) +
  geom_line(color = "#5F9EA0", linewidth = 1) +
  labs(
    title = "Simulated Subscribers with Seasonal Peaks",
    x = "",
    y = " "
  ) +
  usaid_plot()
Let’s also add in some confounding
Our simulated data also include higher variation for baseball subscribers, particularly for the Under 13 age group.
Expand for code
# Ridgeline densities of weekly subscriber counts, one ridge per age group
# and one fill colour per sport.
weekly_group_counts <- sim_data |>
  count(week, sport, age_group) |>
  mutate(
    # Relabel the raw codes and fix the display order in a single step
    age_group = factor(
      age_group,
      levels = c("U13", "13_18", "18plus"),
      labels = c("Under 13", "13-18", "18+")
    ),
    sport = Hmisc::capitalize(sport)
  )

ggplot(weekly_group_counts) +
  geom_density_ridges(
    aes(x = n, y = fct_rev(age_group), fill = sport),
    alpha = 0.6
  ) +
  usaid_plot() +
  scale_fill_viridis_d() +
  theme(legend.position = "top") +
  labs(
    x = "",
    y = "",
    title = "Distribution of Subscribers by Age Group and Sport"
  )
Forecasting
Prophet does well on our simulated data even if we know there is some predictive value in age group and sport.
Expand for code
# Train/test split: the final 52 weeks are held out for forecast testing
train_cutoff <- max(summary_data$week) - weeks(52)
train_data <- summary_data |>
  filter(week <= train_cutoff)
test_data <- summary_data |>
  filter(week > train_cutoff)

# Full-history Prophet fit. It is kept under its original name because the
# cross-validation chunk later in the deck runs on `prophet_model` and needs
# more history than the training window alone provides.
prophet_data <- summary_data |>
  select(ds = week, y = total_subs)
prophet_model <- prophet(prophet_data)

# Holdout fit: Prophet only sees the training window, so the forecast drawn
# below is a genuine out-of-sample prediction for the last year. Fitting on
# the full series and plotting its fitted values for that year would be an
# in-sample fit and would overstate accuracy.
prophet_holdout_model <- train_data |>
  select(ds = week, y = total_subs) |>
  prophet()

future <- make_future_dataframe(
  prophet_holdout_model,
  periods = nrow(test_data),
  freq = "week"
)
forecast <- predict(prophet_holdout_model, future) |>
  mutate(ds = as_date(ds))

# Keep only the predictions that fall in the held-out period
holdout_forecast <- forecast |>
  filter(ds > train_cutoff) |>
  select(week = ds, yhat, yhat_lower, yhat_upper)

# Actuals for the whole series with the holdout forecast and its interval
summary_data |>
  select(week, total_subs) |>
  left_join(holdout_forecast, by = "week") |>
  ggplot(aes(x = week)) +
  geom_line(aes(y = total_subs), col = "#440154FF") +
  geom_line(aes(y = yhat), col = "#22A884FF", linetype = "dashed") +
  geom_ribbon(
    aes(y = yhat, ymin = yhat_lower, ymax = yhat_upper),
    fill = "#22A884FF",
    alpha = 0.3
  ) +
  usaid_plot() +
  labs(
    x = "",
    y = "Total Subscribers",
    title = "Prophet Model Against Actuals for the last year in data"
  )
Prophet’s performance stats are really strong
Even without accounting for covariates, Prophet does well.
Expand for code
# Rolling-origin cross-validation of the Prophet model: 100 weeks of initial
# history, a new cutoff every week, and a 52-week horizon per cutoff.
df_cv <- cross_validation(
  prophet_model,
  initial = 100,
  period = 1,
  horizon = 52,
  units = "weeks"
)

# Error metrics summarised by forecast horizon
df_perf <- tibble(performance_metrics(df_cv))

# RMSE by horizon, with the overall mean RMSE as a dashed reference line.
# `linewidth` replaces `size`, which is deprecated for lines in ggplot2 >= 3.4.
df_perf |>
  ggplot(aes(x = horizon, y = rmse)) +
  geom_line(color = "#22A884FF", linewidth = 1) +
  geom_hline(yintercept = mean(df_perf$rmse), linetype = "dashed") +
  usaid_plot() +
  labs(
    x = "Number of forecast days",
    y = "RMSE",
    title = "Prophet's error is low even for dates a year out",
    subtitle = "Indeed, so low, I kind of worry about overfitting",
    caption = "Dashed line shows mean RMSE"
  )
BSTS (Bayesian structural time series) is similar to Prophet in that it decomposes a time series into different elements
BUT unlike Prophet, it allows for incorporation of covariates and lets users understand the relationships between different variables and the outcome of interest
This last point is important for product segmentation and for understanding why a forecasted trend is expected
Let’s look under the hood at BSTS and then show some results!
BSTS uses “states” to forecast. There are various states that are summed to actually estimate the forecast, but the main ones are:
Trend: captures the overall direction
Seasonal: captures repeating patterns
Regression: captures covariates, important for shocks/outliers
From Scott and Varian 2013, Google
Let’s look under the hood at BSTS and then show some results!
Expand for code
# Illustration of a spike-and-slab prior: a dashed vertical line at zero
# (labelled "Spike") next to a density of positive draws (labelled "Slab").

# Where the two text labels are placed on the plot
prior_labels <- data.frame(
  x = c(-0.151347523196993, 1.83561826421733),
  y = c(0.629476876346744, 0.551140209769703),
  label = c("Spike", "Slab")
)

# Absolute values of 1,000 normal draws (mean 2, sd 0.5) form the slab
slab_draws <- tibble(x = abs(rnorm(1000, mean = 2, sd = 0.5)))

ggplot(slab_draws, aes(x)) +
  geom_density(lwd = 1, fill = "#FDE725FF", alpha = 0.4) +
  geom_vline(xintercept = 0, lwd = 1, linetype = "dashed") +
  usaid_plot() +
  xlim(-1, 4) +
  geom_text(
    data = prior_labels,
    mapping = aes(x = x, y = y, label = label),
    family = "Gill Sans",
    inherit.aes = FALSE,
    size = 5
  ) +
  labs(x = "", y = "")
Both Prophet and BSTS decompose time series into parts like trend and seasonality.
Prophet assumes a fixed, repeating seasonal pattern on top of a trend with explicit change points, while BSTS states adapt over time based on the data
BSTS uses spike and slab priors in the regression component–this helps identify confounders and avoid overfitting
Forecasting with BSTS
We get similar results but with less uncertainty around the forecast compared to our prophet model.
Expand for code
# Train/test split: the final 52 weeks are held out
train_cutoff <- max(summary_data$week) - weeks(52)
train_data <- summary_data |>
  filter(week <= train_cutoff)
test_data <- summary_data |>
  filter(week > train_cutoff)

# State specification: a local level plus a 52-season seasonal component
ss <- AddLocalLevel(list(), train_data$total_subs)
ss <- AddSeasonal(ss, train_data$total_subs, nseasons = 52)

# Regress the weekly total on every other column except `week`.
# `seed` is passed explicitly because set.seed() does not control the RNG of
# the bsts sampler; without it each run gives different draws.
# NOTE(review): the regressors are the per-group counts whose row sum defines
# total_subs, so together they determine the response exactly. Confirm that
# this is the intended comparison against the covariate-free Prophet model.
bsts_model <- bsts(
  total_subs ~ . - week,
  state.specification = ss,
  niter = 2000,
  data = train_data,
  ping = -1,
  seed = 6292025
)

# Predict the held-out 52 weeks using the true covariate values
test_cov <- test_data |>
  select(-total_subs)
bsts_forecast <- predict(
  bsts_model,
  horizon = 52,
  newdata = test_cov,
  seed = 6292025
)

# Posterior mean plus the lower/upper rows of the prediction interval
bsts_df <- tibble(
  mean = bsts_forecast$mean,
  lower = bsts_forecast$interval[1, ],
  upper = bsts_forecast$interval[2, ],
  week = test_data$week
)

# Actuals for the whole series with the holdout forecast and its interval
ggplot(summary_data, aes(x = week)) +
  geom_line(aes(y = total_subs), col = "#440154FF") +
  geom_line(
    data = bsts_df,
    aes(y = mean),
    col = "#22A884FF",
    linetype = "dashed"
  ) +
  geom_ribbon(
    data = bsts_df,
    aes(y = mean, ymin = lower, ymax = upper),
    fill = "#22A884FF",
    alpha = 0.3
  ) +
  usaid_plot() +
  labs(
    x = "",
    y = "Total Subscribers",
    title = "BSTS Model Against Actuals for the last year in data"
  )
BSTS Performance
Expand for code
# Rolling-origin evaluation settings
horizon_weeks <- 52   # weeks forecast from each origin
initial_window <- 104 # weeks of history in the first fold
period <- 1           # step, in weeks, between successive origins
niter <- 1500         # MCMC iterations per fold

# Forecast origins: every `period` rows from the initial window up to the
# last row that still leaves a full horizon of actuals to score against.
total_obs <- nrow(summary_data)
last_origin <- total_obs - horizon_weeks
stopifnot(initial_window <= last_origin)
fold_origins <- seq(initial_window, last_origin, by = period)

# Fit BSTS on rows 1..origin, forecast the next `horizon_weeks` rows with
# their covariates, and return the squared error at each horizon step.
# The fold data live inside the function so the global `train_data` and
# `test_data` used by other chunks are not overwritten.
score_fold <- function(origin) {
  fold_train <- summary_data |>
    slice(seq_len(origin))
  fold_test <- summary_data |>
    slice(origin + seq_len(horizon_weeks))

  # Same state components as the main model
  fold_ss <- AddLocalLevel(list(), fold_train$total_subs)
  fold_ss <- AddSeasonal(fold_ss, fold_train$total_subs, nseasons = 52)

  # `seed` makes the sampler reproducible; set.seed() does not reach it
  fold_fit <- bsts(
    total_subs ~ . - week,
    state.specification = fold_ss,
    niter = niter,
    data = fold_train,
    ping = -1,
    seed = 6292025
  )

  fold_forecast <- predict(
    fold_fit,
    horizon = horizon_weeks,
    newdata = fold_test |> select(-total_subs),
    seed = 6292025
  )

  tibble(
    horizon = seq_len(horizon_weeks),
    sq_error = (fold_forecast$mean - fold_test$total_subs)^2
  )
}

# One tibble of squared errors per fold
rmse_results <- lapply(fold_origins, score_fold)

# RMSE per horizon: square root of the mean squared error across folds.
# Taking sqrt() of each squared error before averaging would give the mean
# absolute error instead.
rmse_summary <- bind_rows(rmse_results) |>
  summarise(
    avg_rmse = sqrt(mean(sq_error, na.rm = TRUE)),
    .by = horizon
  )

# RMSE by horizon, with the mean across horizons as a dashed reference line
ggplot(rmse_summary, aes(x = horizon, y = avg_rmse)) +
  geom_line(color = "#22A884FF", linewidth = 1) +
  geom_hline(yintercept = mean(rmse_summary$avg_rmse), linetype = "dashed") +
  usaid_plot() +
  labs(
    x = "Number of forecast weeks",
    y = "RMSE",
    title = "BSTS's error is low even for dates a year out",
    subtitle = "While overfitting may still be an issue,\nwe know we've minimized it through BSTS priors",
    caption = "Dashed line shows mean RMSE"
  )
Checking out the BSTS components
Expand for code
# This chunk (and the inclusion-probability figure in the next chunk) is
# adapted from
# https://multithreaded.stitchfix.com/blog/2016/04/21/forget-arima/

# Burn-in: number of initial MCMC draws to discard
burn <- SuggestBurn(0.1, bsts_model)

# state.contributions is indexed [draw, component, time point]
state_draws <- bsts_model$state.contributions
kept_draws <- seq_len(dim(state_draws)[1]) > burn

# Posterior mean contribution of one state component at each time point.
# A logical mask is used rather than -(1:burn), which would still drop the
# first draw if burn were 0.
component_mean <- function(component) {
  colMeans(state_draws[kept_draws, component, ])
}

# The model was fit on train_data, so its weeks are the time axis. They are
# used directly: as.Date(time(x)) on a Date vector converts the positions
# 1..n to dates and yields days in January 1970, not the real weeks.
model_weeks <- train_data$week
stopifnot(length(model_weeks) == dim(state_draws)[3])

components_withreg <- data.frame(
  Trend = component_mean("trend"),
  Seasonality = component_mean("seasonal.52.1"),
  Regression = component_mean("regression"),
  Date = model_weeks
)

# Long format: one row per date and component
components_withreg <- reshape2::melt(
  components_withreg,
  id.vars = "Date",
  variable.name = "Component",
  value.name = "Value"
)

# One facet per component, each on its own y scale
ggplot(data = components_withreg, aes(x = Date, y = Value)) +
  geom_line(aes(color = Component)) +
  usaid_plot() +
  theme(legend.title = element_blank()) +
  labs(x = "", y = "") +
  facet_grid(Component ~ ., scales = "free")
Expand for code
# Posterior inclusion probability of each regression coefficient: the share
# of post-burn-in draws in which the coefficient is non-zero.
post_burn_coefs <- bsts_model$coefficients[-(1:burn), ]
share_nonzero <- colMeans(post_burn_coefs != 0)

# One row per coefficient; the coefficient name is kept both as the row name
# and as an explicit character column for plotting
inclusionprobs <- data.frame(
  value = share_nonzero,
  Variable = names(share_nonzero),
  stringsAsFactors = FALSE
)

# Bars ordered by inclusion probability
ggplot(data = inclusionprobs, aes(x = reorder(Variable, value), y = value)) +
  geom_col(position = "identity") +
  usaid_plot() +
  theme(axis.text.x = element_text(angle = -90, hjust = 0)) +
  labs(
    x = "",
    y = "",
    title = "Inclusion probabilities",
    subtitle = "Since we do 2,000 MCMC runs,\nwe can see what covariates are most likely to be included\nand which get regularized away"
  )